import numpy as np
import matplotlib.pyplot as plt
from scipy.special import wofz
from matplotlib.colors import LogNorm
import matplotlib.cm as cm

# Route all figure text through LaTeX (text.usetex needs a working LaTeX
# install) in a Computer Modern serif face, with 14 pt base text, 16 pt axes
# titles and 14 pt axis labels.
plt.rc("text", usetex=True)
plt.rc("font", family="serif", serif=["Computer Modern Roman"], size=14)
plt.rc("axes", titlesize=16, labelsize=14)

def f_norm(Omega):
    """Return exp(-pi*Omega/2) / sqrt(8*pi*|Omega|*sinh(pi*|Omega|)) elementwise.

    The expression is evaluated in log space so that it stays finite for large
    |Omega|: a direct evaluation overflows ``sinh`` for |Omega| above ~226 and
    ``exp(-pi*Omega/2)`` for Omega below ~-452, giving 0 or NaN where the true
    value is finite (it tends to 1/sqrt(4*pi*|Omega|) as Omega -> -inf).

    Parameters
    ----------
    Omega : array_like
        Real value(s); converted to float64.

    Returns
    -------
    numpy.ndarray or numpy.float64
        Same shape as ``Omega``. An exact zero is replaced by 1e-16 before
        evaluation, so the result there is large but finite (~1.1e15).
    """
    Om = np.asarray(Omega, dtype=np.float64)
    absOm = np.abs(Om)
    # Avoid 0 * sinh(0) = 0 in the denominator at Omega == 0.
    absOm = np.where(absOm == 0.0, 1e-16, absOm)
    x = np.pi * absOm
    # log(sinh x) = x - log 2 + log(1 - exp(-2x)); -expm1(-2x) is accurate for
    # tiny x and the whole form never overflows for large x.
    log_sinh = x - np.log(2.0) + np.log(-np.expm1(-2.0 * x))
    log_f = -0.5 * np.pi * Om - 0.5 * (np.log(8.0 * np.pi * absOm) + log_sinh)
    return np.exp(log_f)

def M_unified(Omega, Omega_p, chi, chi_p, a, T, omega, *, g=1.0, hbar=1.0):
    """Evaluate the complex amplitude M elementwise over (Omega, Omega_p).

    Computes

        M = g^2 a^2 / (2 hbar^2) * chi * chi_p * Omega * Omega_p
            * exp(i (Omega + Omega_p) ln a)
            * f_norm(-chi Omega) * f_norm(-chi_p Omega_p)
            * exp(-a^2 T^2 (chi Omega + chi_p Omega_p)^2 / 4)
            * w(T (a (chi Omega - chi_p Omega_p) / 2 - omega))

    where w is the Faddeeva function ``scipy.special.wofz``.

    Parameters
    ----------
    Omega, Omega_p : array_like
        Broadcast-compatible real values (e.g. meshgrid arrays); converted
        to float64.
    chi, chi_p : float
        Multiplicative factors applied to Omega and Omega_p respectively.
    a : float
        Must be > 0: ``np.log(a)`` is non-finite for a <= 0 and the result
        becomes NaN.
    T, omega : float
        Enter the Gaussian and Faddeeva factors as shown above.
    g, hbar : float, keyword-only
        Constants in the prefactor g**2 / hbar**2. The defaults of 1.0
        reproduce the values that were previously hard-coded.

    Returns
    -------
    numpy.ndarray of complex128
        Broadcast shape of the inputs.
    """
    Omega = np.asarray(Omega, dtype=np.float64)
    Omega_p = np.asarray(Omega_p, dtype=np.float64)
    prefactor = (g**2 * a**2) / (2.0 * hbar**2) * chi * chi_p * Omega * Omega_p
    f_product = f_norm(-chi * Omega) * f_norm(-chi_p * Omega_p)
    # exp(i (Omega + Omega_p) ln a): a pure phase for real inputs and a > 0.
    phase_factor = np.exp(1j * (Omega + Omega_p) * np.log(a))
    gaussian_term = np.exp(-0.25 * (a**2) * (T**2) * (chi*Omega + chi_p*Omega_p)**2)
    faddeeva_arg = T * (0.5 * a * (chi*Omega - chi_p*Omega_p) - omega)
    # wofz is evaluated on a complex argument; cast explicitly.
    faddeeva_term = wofz(np.asarray(faddeeva_arg, dtype=np.complex128))
    return prefactor * phase_factor * f_product * gaussian_term * faddeeva_term

# Parameters
omega = 1.0
a_val = 1.0   # fixed acceleration
axis_range = 3
n_points = 400

# Four T values (→ four Tω values)
T_values = [0.5, 1.0, 2.0, 5.0]
labels_T = [fr"$\mathbf{{T\omega = {T*omega:.1f}}}$" for T in T_values]

# The (Omega, Omega') evaluation grid does not depend on T: build it once.
Omega_vals = np.linspace(-axis_range, axis_range, n_points)
Op_vals = np.linspace(-axis_range, axis_range, n_points)
O_grid, Op_grid = np.meshgrid(Omega_vals, Op_vals, indexing='xy')

# matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap is the
# supported lookup. Copy before customising so the registered 'jet' is untouched.
my_cmap = plt.get_cmap('jet').copy()
my_cmap.set_under('darkblue')   # values below vmin
my_cmap.set_bad('darkblue')     # masked values

# Pass 1: evaluate |M_RR|^2 for every T. Non-positive and non-finite entries
# are masked because they cannot be placed on a logarithmic colour scale.
prob_panels = []
for T in T_values:
    Z_RR = M_unified(O_grid, Op_grid, 1, 1, a=a_val, T=T, omega=omega)
    Prob_RR = np.abs(Z_RR)**2
    prob_panels.append(
        np.ma.masked_where(~np.isfinite(Prob_RR) | (Prob_RR <= 0.0), Prob_RR)
    )

# One normalisation and one level set for all panels, so the single shared
# colorbar drawn below is valid for every panel, not just the last one.
# Falls back to vmax = 1 if every panel is fully masked.
vmax = max((p.max() for p in prob_panels if p.count() > 0), default=1.0)
vmin = vmax * 1e-6
log_levels = np.logspace(np.log10(vmin), np.log10(vmax), num=100)
shared_norm = LogNorm(vmin=vmin, vmax=vmax)

# Pass 2: draw the 2x2 grid of panels.
fig, axs = plt.subplots(2, 2, figsize=(12, 10))

for i, (ax, label_T, Prob_masked) in enumerate(zip(axs.flat, labels_T, prob_panels)):
    ax.set_facecolor('darkblue')

    im = ax.contourf(
        O_grid, Op_grid, Prob_masked,
        levels=log_levels,
        cmap=my_cmap,
        norm=shared_norm,
        extend='min',
    )

    ax.set_title(fr"$\mathbf{{a = {a_val}}}$, {label_T}", fontsize=14)

    # Axis labels only on the outer edges of the 2x2 grid.
    if i in (0, 2):   # left column
        ax.set_ylabel(r"$\mathbf{\Omega'}$")
    if i in (2, 3):   # bottom row
        ax.set_xlabel(r"$\mathbf{\Omega}$")

    ax.axhline(0, color='white', linestyle='--', linewidth=0.5, alpha=0.5)
    ax.axvline(0, color='white', linestyle='--', linewidth=0.5, alpha=0.5)
    ax.set_aspect('equal')

# Shared colorbar; valid for every panel because all use `shared_norm`.
fig.subplots_adjust(right=0.87)
cbar_ax = fig.add_axes([0.90, 0.15, 0.02, 0.7])
fig.colorbar(im, cax=cbar_ax, label=r'$\mathbf{|\mathcal{M}|^2}$')

# Title text follows a_val instead of hard-coding "a=1".
fig.suptitle(
    fr"\textbf{{RR Emission Spectrum with Fixed $a={a_val:g}$ and Varying $T\omega$}}",
    fontsize=18,
    y=0.98,
)

output_filename = "fig_RR_fixed_a_Tomega.png"
fig.savefig(output_filename, dpi=300, bbox_inches='tight')
print(f"✅ Saved as {output_filename}")
